{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d116b97f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# CUDA_VISIBLE_DEVICES=\"4,5,6\" vllm/bin/vllm serve \"mesolitica/Malaysian-Qwen2.5-7B-Dialect-Reasoning-GRPO\" --dtype float32 --port 8007 --tensor_parallel_size=3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4d8ac425",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "bacaeeef",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = load_dataset('huseinzol05/malaysian-dialect-qa', split = 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e05b4076",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/lib/python3/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.17.3 and <1.25.0 is required for this version of SciPy (detected version 1.26.4\n",
      "  warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n"
     ]
    }
   ],
   "source": [
    "dataset_lang = load_dataset('huseinzol05/malaysian-dialect-qa-lang', split = 'test')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "937e7b2e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "140"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "questions = []\n",
    "for i in range(len(dataset)):\n",
    "    q = dataset[i]['question']\n",
    "    questions.append((i, q))\n",
    "    \n",
    "len(questions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "49453d32",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mkdir: cannot create directory ‘Malaysian-Qwen2.5-7B-Reasoning-GRPO-fp32’: File exists\r\n"
     ]
    }
   ],
   "source": [
    "folder = 'Malaysian-Qwen2.5-7B-Reasoning-GRPO-fp32'\n",
    "# !rm -rf {folder}\n",
    "!mkdir {folder}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "20b32ea4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import requests\n",
    "import os\n",
    "import json\n",
    "import re\n",
    "\n",
    "def generate_answer(row, repeat = 5):\n",
    "    no, q = row\n",
    "    for k in range(repeat):\n",
    "        filename = os.path.join(folder, f'{no}-{k}.json')\n",
    "        try:\n",
    "            with open(filename) as fopen:\n",
    "                json.load(fopen)\n",
    "            continue\n",
    "        except:\n",
    "            pass\n",
    "\n",
    "        json_data = {\n",
    "            'model': \"mesolitica/Malaysian-Qwen2.5-7B-Dialect-Reasoning-GRPO\",\n",
    "            'messages': [\n",
    "                {'role': 'system', 'content': 'You are going to enter reasoning mode. First, you try to think step-by-step in Malay. After that, put your final answer within $\\\\boxed{}$.'},\n",
    "                {'role': 'user', 'content': q},\n",
    "            ],\n",
    "            'max_tokens': 24000,\n",
    "        }\n",
    "        \n",
    "        while True:\n",
    "            response = requests.post('http://localhost:8007/v1/chat/completions', json=json_data)\n",
    "            r = response.json()['choices'][0]['message']['content'].strip()\n",
    "            answers = re.findall(r\"\\$boxed\\{(.*?)\\}\\$\", r)\n",
    "            if len(answers) == 1:\n",
    "                a = answers[0]\n",
    "                with open(filename, 'w') as fopen:\n",
    "                    json.dump(a, fopen)\n",
    "                    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "07ab3ae3",
   "metadata": {},
   "outputs": [],
   "source": [
    "generate_answer(questions[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "588c45e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "def consumer(queue, name):\n",
    "    while True:\n",
    "        if queue.qsize() == 0:\n",
    "            break\n",
    "        item = queue.get()\n",
    "        generate_answer(item)\n",
    "    print(f'consumer {name} done')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "c5035c80",
   "metadata": {},
   "outputs": [],
   "source": [
    "from threading import Thread\n",
    "from queue import Queue\n",
    "\n",
    "queue = Queue()\n",
    "for u in questions:\n",
    "    queue.put(u)\n",
    "    \n",
    "ori_size = queue.qsize()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "7c25154c",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 99%|████████████████████████████████████████████████████████████████████████████████████████████████████████▎| 139/140 [06:36<00:02,  2.86s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "consumer 37 done\n",
      "consumer 45 done\n"
     ]
    }
   ],
   "source": [
    "from tqdm import tqdm\n",
    "\n",
    "max_worker = 50\n",
    "consumers = [Thread(target=consumer, args=(queue,i)) for i in range(max_worker)]\n",
    "for i in range(len(consumers)):\n",
    "    consumers[i].start()\n",
    "    \n",
    "pbar = tqdm(total=ori_size)\n",
    "last_size = 0\n",
    "while True:\n",
    "    size = queue.qsize()\n",
    "    if size == 0:\n",
    "        break\n",
    "    left = ori_size - size\n",
    "    minus = left - last_size\n",
    "    if minus > 0:\n",
    "        pbar.update(minus)\n",
    "        last_size += minus\n",
    "\n",
    "pbar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "11c4fb46",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "consumer 15 done\n",
      "consumer 42 done\n"
     ]
    }
   ],
   "source": [
    "from sacrebleu.metrics import CHRF\n",
    "from glob import glob\n",
    "from collections import defaultdict\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "9310cb8a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 140/140 [00:00<00:00, 782.52it/s]\n"
     ]
    }
   ],
   "source": [
    "chrf = CHRF()\n",
    "pairs = defaultdict(list)\n",
    "\n",
    "for i in tqdm(range(len(dataset_lang))):\n",
    "    from_lang = dataset_lang[i]['from_lang']\n",
    "    to_lang = dataset_lang[i]['to_lang']\n",
    "    gt = dataset_lang[i]['answer']\n",
    "    pair = f'{from_lang}<>{to_lang}'\n",
    "    files = glob(f'Malaysian-Qwen2.5-7B-Reasoning-GRPO-fp32/{i}-*.json')\n",
    "    if len(files) < 5:\n",
    "        print(i, len(files))\n",
    "    scores = []\n",
    "    for f in files:\n",
    "        with open(f) as fopen:\n",
    "            d = json.load(fopen)\n",
    "        score = chrf.corpus_score([d], [[gt]]).score\n",
    "        scores.append(score)\n",
    "\n",
    "    max_score = max(scores)\n",
    "    pairs[pair].append(max_score)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "181968c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "From: johor To: malay, score: 58.2189619529139\n",
      "From: kedah To: malay, score: 59.21260384746205\n",
      "From: pahang To: malay, score: 53.506270589822165\n",
      "From: negeri sembilan To: malay, score: 56.94870448682657\n",
      "From: kelantan To: malay, score: 50.64768195652429\n",
      "From: penang To: malay, score: 62.964413639258034\n",
      "From: melaka To: malay, score: 56.24541676643081\n",
      "From: malay To: johor, score: 54.83246740931249\n",
      "From: malay To: kedah, score: 59.069394967356274\n",
      "From: malay To: pahang, score: 59.695207458023745\n",
      "From: malay To: negeri sembilan, score: 50.69885056697714\n",
      "From: malay To: kelantan, score: 44.66310165425512\n",
      "From: malay To: penang, score: 65.39795752468879\n",
      "From: malay To: melaka, score: 72.39183991789344\n"
     ]
    }
   ],
   "source": [
    "for k, v in pairs.items():\n",
    "    l, r = k.split('<>')\n",
    "    print(f'From: {l} To: {r}, score:', np.mean(v))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "2b70b033",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "56.82057903417684"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = \"\"\"\n",
    "From: johor To: malay, score: 58.2189619529139\n",
    "From: kedah To: malay, score: 59.21260384746205\n",
    "From: pahang To: malay, score: 53.506270589822165\n",
    "From: negeri sembilan To: malay, score: 56.94870448682657\n",
    "From: kelantan To: malay, score: 50.64768195652429\n",
    "From: penang To: malay, score: 62.964413639258034\n",
    "From: melaka To: malay, score: 56.24541676643081\n",
    "\"\"\"\n",
    "\n",
    "scores = []\n",
    "for l in x.split('\\n'):\n",
    "    if 'score:' not in l:\n",
    "        continue\n",
    "    \n",
    "    scores.append(float(l.split('score: ')[1]))\n",
    "    \n",
    "np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "017c4f84",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "58.10697421407243"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = \"\"\"\n",
    "From: malay To: johor, score: 54.83246740931249\n",
    "From: malay To: kedah, score: 59.069394967356274\n",
    "From: malay To: pahang, score: 59.695207458023745\n",
    "From: malay To: negeri sembilan, score: 50.69885056697714\n",
    "From: malay To: kelantan, score: 44.66310165425512\n",
    "From: malay To: penang, score: 65.39795752468879\n",
    "From: malay To: melaka, score: 72.39183991789344\n",
    "\"\"\"\n",
    "\n",
    "scores = []\n",
    "for l in x.split('\\n'):\n",
    "    if 'score:' not in l:\n",
    "        continue\n",
    "    \n",
    "    scores.append(float(l.split('score: ')[1]))\n",
    "    \n",
    "np.mean(scores)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3bb1c35c",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
